In [7]:
# Re-import edited local modules (e.g. modules.utils) automatically before each cell runs.
%load_ext autoreload
%autoreload 2
In [8]:
import pickle

from tqdm import tqdm

import numpy as np
import scipy.interpolate

from sklearn.preprocessing import KBinsDiscretizer

import pandas as pd

import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

from modules.utils.general_utils.utilities import group_wise_binning

from plotly.subplots import make_subplots
import plotly.graph_objects as go
In [9]:
def interpolate_paths(z, x, y, c, rep_id, n_points=100):
    """Split a trajectory into runs of consecutive steps and smooth each run.

    A new run starts wherever two neighbouring ``z`` values differ by anything
    other than exactly 1 (e.g. a skipped session). Every run with at least two
    points is resampled on ``n_points`` evenly spaced ``z`` values using
    ``scipy.interpolate.interp1d``. The spline order grows with the run length
    (2 points: linear, 3: quadratic, 4 or more: cubic) because ``interp1d``
    needs at least ``order + 1`` points. Runs of a single point are returned
    unchanged (as arrays).

    Parameters
    ----------
    z : array-like
        Step index of every observation (the session number at the call site).
        Runs are detected in the given order, so ``z`` is expected to be
        ascending.
    x, y : array-like
        Coordinates of every observation; same length as ``z``.
    c : array-like
        Per-observation value used for line colouring; same length as ``z``.
        A pandas Series is accepted; its index is ignored.
    rep_id : hashable
        Identifier of the trajectory. Currently unused; kept so existing calls
        keep working.
    n_points : int, optional
        Number of samples for each interpolated run (default 100).

    Returns
    -------
    list of tuple of numpy.ndarray
        One ``(z, x, y, c)`` tuple per run, in input order.
    """
    # Spline order by run length; longer runs fall back to "cubic" below.
    interp_kind = {2: "linear", 3: "quadratic"}

    # Coerce to plain arrays: np.split on a pandas Series goes through the
    # deprecated Series.swapaxes, and positional slicing must ignore the index.
    arrays = [np.asarray(values) for values in (z, x, y, c)]

    # Positions where the step changes by anything other than +1 start a new run.
    run_starts = np.where(np.diff(arrays[0]) != 1)[0] + 1
    runs = zip(*(np.split(values, run_starts) for values in arrays))

    paths = []
    for z_run, x_run, y_run, c_run in runs:
        if len(z_run) < 2:
            # Nothing to interpolate between; keep the lone observation.
            paths.append((z_run, x_run, y_run, c_run))
            continue

        kind = interp_kind.get(len(z_run), "cubic")
        z_dense = np.linspace(np.min(z_run), np.max(z_run), n_points)
        x_dense, y_dense, c_dense = (
            scipy.interpolate.interp1d(z_run, values, kind=kind)(z_dense)
            for values in (x_run, y_run, c_run)
        )
        paths.append((z_dense, x_dense, y_dense, c_dense))

    return paths
In [10]:
# Load the saved model outputs and bin the predicted activity.
# Forward slashes keep these relative paths valid on both Windows and POSIX;
# backslash separators only resolve on Windows.
# NOTE(review): pickle.load can execute arbitrary code - only open containers
# produced by this project.
N_KEPT = 5  # only the first N_KEPT entries of each container field are used - TODO document why

with open('results/saved_data_containers/melchior.pkl', 'rb') as container:
    DATA_CONTAINER = pickle.load(container)

raw_predictions = DATA_CONTAINER['prediction_ds']['tar_activity']
raw_contexts = DATA_CONTAINER['context']

kept_predictions = [raw_predictions[i] for i in range(N_KEPT)]
contexts = [raw_contexts[i] for i in range(N_KEPT)]

# group_wise_binning returns one chunk per group; stack them into a flat array.
predictions = np.hstack(
    group_wise_binning(kept_predictions, n_bins=100, grouper=contexts)
)
In [11]:
# Load the per-session embedding table (UMAP_1 / UMAP_2 columns) once and
# attach the binned predictions from the previous cell.
# NOTE(review): assumes the CSV rows are in the same order as `predictions`;
# confirm against the pipeline that wrote both files.
TARGET_CONTEXT = 6       # only rows with this `context` value are analysed
N_SESSIONS_REQUIRED = 4  # keep users whose highest `session` value + 1 equals this

embeddings = pd.read_csv(
    'results/saved_dim_reduction/melchior_eng_emb_temporal.csv'
).assign(**{'Future Session Activity': predictions})

context_rows = embeddings[embeddings['context'] == TARGET_CONTEXT]

sessions_per_user = context_rows.groupby('user_id')['session'].max() + 1
users = sessions_per_user[sessions_per_user == N_SESSIONS_REQUIRED].index.values

df = context_rows[context_rows['user_id'].isin(users)]
In [12]:
# Rank users by how much their predicted future-session activity varies across
# their sessions, then split them into quantile bins (0 = least variable).
N_VARIABILITY_BINS = 9

discretizer = KBinsDiscretizer(
    n_bins=N_VARIABILITY_BINS, encode='ordinal', strategy='quantile'
)

# NOTE: in `variability_rank` the 'Future Session Activity' column holds each
# user's population variance (np.var, ddof=0) of that activity, not the
# activity itself; the name is kept for compatibility with later cells.
variability_rank = (
    df.groupby('user_id')['Future Session Activity']
    .agg(lambda activity: np.var(activity.values))
    .reset_index()
)
variability_rank['rank'] = discretizer.fit_transform(
    variability_rank['Future Session Activity'].values.reshape((-1, 1))
)

ax = sns.histplot(variability_rank['rank'].values)
ax.set(
    title='Users per variability bin',
    xlabel='Variability bin (0 = lowest variance)',
    ylabel='Number of users',
);
Out[12]:
<matplotlib.axes._subplots.AxesSubplot at 0x28192917588>
In [14]:
# Plot sampled user trajectories through the 2-D embedding over sessions:
# one 3-D panel per variability bin, session index on the depth axis.
SEED = 0                    # fixes the user sample so the figure is reproducible
MAX_USERS_PER_PANEL = 700   # cap on trajectories drawn per panel
PANELS = [                  # (variability bin, panel title); bins come from the 9-bin discretizer
    (0, 'Low Variability'),
    (4, 'Medium Variability'),
    (8, 'High Variability'),
]
zoom = 3.5

rng = np.random.default_rng(SEED)

fig = make_subplots(
    rows=1,
    cols=len(PANELS),
    specs=[[{'type': 'scatter3d'} for _ in PANELS]],
    subplot_titles=tuple(title for _, title in PANELS),
    horizontal_spacing=0.01,
    vertical_spacing=0.05,
    shared_xaxes=True,
    shared_yaxes=True,
)

# Group once instead of re-scanning the whole frame for every user and column.
rows_by_user = df.groupby('user_id')

for col, (rank, title) in enumerate(PANELS, start=1):
    candidate_ids = variability_rank.loc[
        variability_rank['rank'] == rank, 'user_id'
    ].values
    sampled_ids = rng.choice(
        candidate_ids,
        min(len(candidate_ids), MAX_USERS_PER_PANEL),
        replace=False,
    )
    print(f'{title}: {len(candidate_ids)} users in bin, plotting {len(sampled_ids)}')

    traces = []
    for user_id in sampled_ids:
        user_rows = rows_by_user.get_group(user_id)
        for z_path, x_path, y_path, c_path in interpolate_paths(
            user_rows['session'].values,
            user_rows['UMAP_1'].values,
            user_rows['UMAP_2'].values,
            user_rows['Future Session Activity'].values,
            user_id,
        ):
            traces.append(go.Scatter3d(
                # Session goes on the plotly y (depth) axis, UMAP_2 on the vertical z axis.
                x=x_path, y=z_path, z=y_path,
                mode='lines',
                line=dict(
                    color=c_path,
                    cmin=0,
                    cmid=50,
                    cmax=100,
                    cauto=False,
                    colorscale='RdBu',
                    colorbar=dict(),
                    width=0.6,
                ),
                opacity=0.6,
            ))

    # Adding all traces of a panel in one call is much faster than one-by-one.
    if traces:
        fig.add_traces(traces, rows=[1] * len(traces), cols=[col] * len(traces))

fig.update_layout(
    width=1050,
    height=500,
    margin=dict(r=1, l=1),
    showlegend=False,
    autosize=False,
    template="plotly_white",
)

# Same camera and aspect ratio for every 3-D panel (scene, scene2, scene3).
fig.update_scenes(
    aspectmode='manual',
    aspectratio=dict(x=1, y=3, z=1),
    camera=dict(
        up=dict(x=0, y=0, z=1),
        center=dict(x=0.5, y=0.5, z=0),
        eye=dict(x=0.5 * zoom, y=0.8 * zoom, z=0.75 * zoom),
    ),
)

fig.update_scenes(
    xaxis_title_text='UMAP 1',
    zaxis_title_text='UMAP 2',
    yaxis_title_text=r"$t$",
    # Session values 0-3 are labelled 1-4 on the time axis.
    yaxis=dict(
        tickmode='array',
        tickvals=[0, 1, 2, 3],
        ticktext=[1, 2, 3, 4],
    ),
)

fig.show()
1395
1402
1395
In [ ]: